# Draw `n` multivariate-normal vectors with mean `mu` and covariance `Sigma`.
#
# Each row is one draw: standard-normal noise is right-multiplied by the
# upper Cholesky factor of Sigma and then shifted by mu. Supply `chol.Sigma`
# directly to skip the factorisation.
# Returns an n x length(mu) matrix.
rmvnorm <- function(n, mu, Sigma, chol.Sigma = chol(Sigma)) {
  dims <- length(mu)
  noise <- matrix(rnorm(n * dims), n, dims)
  correlated <- noise %*% chol.Sigma
  # Column-major recycling: element [i, j] receives mu[j].
  correlated + rep(c(mu), each = n)
}

# Value at which a control variable is held fixed.
#
# type  one-element character vector holding the column's class; its *name*
#       is the column to summarise (e.g. c(age = "numeric")).
# data  the data frame containing that column.
# Returns:
#   "character" -> the most frequent value as a one-element factor carrying
#                  the levels of as.factor(column);
#   "factor"    -> the most frequent level, with the column's own levels;
#   "numeric"/"integer" -> the median;
#   anything else -> NULL (invisibly).
# Ties for the mode resolve to the first value in table() order.
outputdata <- function(type, data) {
  column <- names(type)
  modal_value <- function(values) {
    counts <- table(values)
    names(counts)[which.max(counts)]
  }
  switch(type,
    character = factor(modal_value(data[[column]]),
                       levels = levels(as.factor(data[[column]]))),
    factor = factor(modal_value(data[, column]),
                    levels = levels(data[[column]])),
    numeric = ,
    integer = median(data[[column]])
  )
}

# Simulate coefficient draws for every topic.
#
# parameters  list with one element per topic; each element is a list of
#             parameter sets, each providing `est` (mean vector) and `vcov`
#             (covariance matrix).
# nsims       number of draws per parameter set.
# Returns a list aligned with `parameters`. Element i row-binds the draws from
# every parameter set of topic i, giving length(parameters[[i]]) * nsims rows.
simBetas <- function(parameters, nsims = 100) {
  # lapply keeps the result aligned with `parameters`. It also returns an
  # empty list for empty input, where the former 1:length() loop failed
  # with "subscript out of bounds".
  lapply(parameters, function(topic_params) {
    draws <- lapply(topic_params, function(p) {
      rmvnorm(n = nsims, mu = p$est, Sigma = p$vcov)
    })
    do.call(rbind, draws)
  })
}

# Build the counterfactual covariate rows ("cdata") and their design matrix
# ("cmatrix") used to simulate expected topic proportions.
#
# The covariate of interest takes the values implied by `method`:
#   "pointestimate"  every unique observed value;
#   "difference"     cov.value1, then cov.value2 (two rows);
#   "continuous"     `npoints` evenly spaced values from min to max
#                    (numeric/integer covariates only).
# Every other variable in prep$varlist is held at the value returned by
# outputdata() (mode for character/factor, median for numeric/integer). The
# moderator, if given, is held at moderator.value.
#
# prep  object providing prep$data, prep$varlist and prep$formula.
# Returns list(cdata = <data.frame>, cmatrix = <makeDesignMatrix() result>).
produce_cmatrix <- function(prep, covariate, method, cov.value1 = NULL,
                            cov.value2 = NULL, npoints = 100, moderator = NULL,
                            moderator.value = NULL) {
  # Validate inputs before doing any work.
  if (!is.null(moderator) && is.null(moderator.value)) {
    stop("Please specify the value of the moderator")
  }
  if (is.null(moderator) && !is.null(moderator.value)) {
    stop("Please specify the moderator")
  }
  missing_vars <- setdiff(prep$varlist, names(prep$data))
  if (length(missing_vars) > 0) {
    stop("Variables not found in prep$data: ",
         paste(missing_vars, collapse = ", "))
  }
  if (length(covariate) != 1L || !covariate %in% prep$varlist) {
    stop("Covariate '", paste(covariate, collapse = ", "),
         "' is not a variable in the model.")
  }

  # One class label per column, named by column. class() can return several
  # labels (e.g. c("ordered", "factor") or c("POSIXct", "POSIXt")). The
  # previous unlist() then shifted every later type away from its variable.
  # All factors, ordered or not, collapse to "factor".
  col_types <- vapply(prep$data, function(v) {
    if (is.factor(v)) "factor" else class(v)[1L]
  }, character(1))
  types <- col_types[prep$varlist]

  cov_type <- types[[covariate]]
  cov_values <- prep$data[[covariate]]
  is_categorical <- cov_type %in% c("character", "factor")
  is_numeric <- cov_type %in% c("numeric", "integer")
  unsupported <- paste0("Covariate '", covariate, "' has unsupported type '",
                        cov_type, "'.")

  # Factor values must carry the original column's levels (and orderedness)
  # so the design matrix reproduces the model's columns.
  as_level_of <- function(value, reference) {
    factor(as.character(value), levels = levels(as.factor(reference)),
           ordered = is.ordered(reference))
  }

  # Values of the covariate of interest.
  if (method == "pointestimate") {
    if (cov_type == "character") {
      cdata <- data.frame(factor(unique(cov_values)))
    } else if (is_categorical || is_numeric) {
      cdata <- data.frame(unique(cov_values))
    } else {
      stop(unsupported)
    }
  } else if (method == "difference") {
    if (is_categorical) {
      values <- as_level_of(c(as.character(cov.value1), as.character(cov.value2)),
                            cov_values)
      # A value outside the levels would become NA and drop a design-matrix row.
      if (anyNA(values)) {
        stop("cov.value1 and cov.value2 must both be levels of '", covariate, "'.")
      }
      cdata <- data.frame(values)
    } else if (is_numeric) {
      cdata <- data.frame(c(cov.value1, cov.value2))
    } else {
      stop(unsupported)
    }
  } else if (method == "continuous") {
    if (!is_numeric) {
      stop("Covariate of interest must be numeric")
    }
    cdata <- data.frame(seq(min(cov_values), max(cov_values),
                            length.out = npoints))
  } else {
    stop("Unknown method '", method, "'.")
  }
  names(cdata) <- covariate

  # Hold every control variable at its representative value.
  controls <- which(prep$varlist != covariate)
  for (idx in controls) {
    ctrl <- prep$varlist[idx]
    value <- outputdata(types[idx], prep$data)
    if (is.null(value)) {
      stop("Control variable '", ctrl, "' has unsupported type '",
           types[[idx]], "'.")
    }
    if (is.factor(prep$data[[ctrl]])) {
      value <- as_level_of(value, prep$data[[ctrl]])
    }
    cdata[, ctrl] <- value
  }

  # Hold the moderator at the requested value, if applicable.
  if (!is.null(moderator)) {
    mod_values <- prep$data[[moderator]]
    if (is.factor(mod_values) || is.character(mod_values)) {
      cdata[, moderator] <- as_level_of(moderator.value, mod_values)
    } else {
      cdata[, moderator] <- moderator.value
    }
  }

  # Reorder columns to follow the original data. intersect() tolerates data
  # columns that are not model variables; drop = FALSE keeps a data.frame.
  cdata <- cdata[, intersect(names(prep$data), names(cdata)), drop = FALSE]

  cmatrix <- makeDesignMatrix(prep$formula, prep$data, cdata, sparse = FALSE)
  list(cdata = cdata, cmatrix = cmatrix)
}

# Summarise simulated covariate differences in expected topic proportions.
#
# For each requested topic, the coefficient draws are projected through the
# two-row matrix `cmat` (row 1 = cov.value1, row 2 = cov.value2). The
# difference row1 - row2 is then summarised by its mean and an
# equal-tailed interval with tail probability `offset` on each side.
#
# prep      object whose prep$topics maps a topic id onto its position in
#           `simbetas`.
# topics    topic ids to summarise (output rows follow this order).
# cmat      two-row design matrix (see produce_cmatrix()).
# simbetas  list of draw matrices (draws x coefficients) aligned with
#           prep$topics.
# offset    tail probability, i.e. (1 - ci.level) / 2.
# covariate, cdata, cov.value1, cov.value2, ...  accepted for call
#           compatibility; not used here.
# Returns a data.frame with one row per topic:
#   topic  the topic id as given in `topics`
#          (previously 1..k, which mislabelled non-1:K selections);
#   mu, lwr, upr  mean and interval bounds;
#   sig    1 when the interval excludes 0, otherwise 0.
prepDifference <- function(prep, covariate, topics, cdata, cmat, simbetas, offset,
                           cov.value1 = NULL, cov.value2 = NULL, ...) {
  # c(mean, lower, upper) of the simulated differences for one topic id.
  summarise_topic <- function(topic) {
    pos <- match(topic, prep$topics)
    if (is.na(pos)) {
      stop("Topic ", topic, " is not in the estimated effects.")
    }
    sims <- cmat %*% t(simbetas[[pos]])
    delta <- sims[1, ] - sims[2, ]
    c(mean(delta), quantile(delta, c(offset, 1 - offset), names = FALSE))
  }

  # 3 x length(topics) matrix; a 3 x 0 matrix for empty `topics`.
  stats <- vapply(topics, summarise_topic, numeric(3), USE.NAMES = FALSE)
  lwr <- unname(stats[2, ])
  upr <- unname(stats[3, ])
  data.frame(topic = unname(topics),
             mu = unname(stats[1, ]),
             lwr = lwr,
             upr = upr,
             sig = as.numeric((lwr < 0 & upr < 0) | (lwr > 0 & upr > 0)))
}

# Like prepDifference(), but can pool several topics into one summary row.
#
# combine  POSITIONS within `topics` whose simulated proportions are summed
#          before the difference is taken. NULL/FALSE disables pooling.
#          (The former loop iterated topic VALUES and then indexed `topics`
#          with them, which crashed unless topics == 1:K. It also mishandled
#          a single-element `combine` via 2:length(combine).)
# Other arguments are as in prepDifference().
# Returns a data.frame with columns topic, mu, lwr, upr, sig (1 when the
# interval excludes 0).
#   With pooling: one row per non-pooled topic (in `topics` order), then one
#     pooled row. `topic` is character, e.g. "7" and "2+5".
#   Without pooling: one row per topic, `topic` holding the ids as given.
prepDifference.combine <- function(prep, covariate, topics, cdata, cmat,
                                   simbetas, offset,
                                   cov.value1 = NULL, cov.value2 = NULL, ...,
                                   combine = NULL) {
  # 2 x draws matrix of simulated expected proportions for topics[pos].
  simulate_at <- function(pos) {
    idx <- match(topics[pos], prep$topics)
    if (is.na(idx)) {
      stop("Topic ", topics[pos], " is not in the estimated effects.")
    }
    cmat %*% t(simbetas[[idx]])
  }
  # c(mean, lower, upper) of the row1 - row2 differences.
  summarise_sims <- function(sims) {
    delta <- sims[1, ] - sims[2, ]
    c(mean(delta), quantile(delta, c(offset, 1 - offset), names = FALSE))
  }

  if (any(combine != FALSE)) {
    if (any(combine < 1 | combine > length(topics))) {
      stop("`combine` must contain positions within `topics`.")
    }
    singles <- seq_along(topics)[-combine]
    single_stats <- vapply(singles,
                           function(pos) summarise_sims(simulate_at(pos)),
                           numeric(3), USE.NAMES = FALSE)
    pooled_stats <- summarise_sims(Reduce(`+`, lapply(combine, simulate_at)))
    stats <- cbind(single_stats, pooled_stats, deparse.level = 0)
    labels <- c(as.character(topics[singles]),
                paste(topics[combine], collapse = "+"))
  } else {
    stats <- vapply(seq_along(topics),
                    function(pos) summarise_sims(simulate_at(pos)),
                    numeric(3), USE.NAMES = FALSE)
    labels <- unname(topics)
  }

  lwr <- unname(stats[2, ])
  upr <- unname(stats[3, ])
  data.frame(topic = labels,
             mu = unname(stats[1, ]),
             lwr = lwr,
             upr = upr,
             sig = as.numeric((lwr < 0 & upr < 0) | (lwr > 0 & upr > 0)),
             stringsAsFactors = FALSE)
}

# Point-range plot of per-topic differences, flipped so topics run along the
# vertical axis in the order they appear in `df`.
#
# df  data.frame with columns topic, mu, lwr, upr and sig (0/1), as returned
#     by prepDifference().
# When both significance statuses are present, significant rows are black
# and non-significant rows grey. When every row shares one status, all
# points are black if significant and grey otherwise.
# (Previously `col` was passed to ggplot() itself, where it is ignored, so
# an all-non-significant plot came out black.)
# Returns a ggplot object; needs ggplot2 and forcats.
prep.plot <- function(df) {
  sigs <- unique(df$sig)
  if (length(sigs) == 2) {
    p <- ggplot(df, aes(x = forcats::fct_inorder(factor(topic)), y = mu,
                        ymin = lwr, ymax = upr, col = factor(sig))) +
      geom_pointrange() +
      scale_color_manual(name = "",
                         breaks = c(0, 1),
                         values = c("grey", "black"),
                         labels = c("ns", "s"))
  } else {
    # all() on an empty df is TRUE, so no rows no longer errors in if().
    point_col <- if (isTRUE(all(sigs == 1))) "black" else "grey"
    p <- ggplot(df, aes(x = forcats::fct_inorder(factor(topic)), y = mu,
                        ymin = lwr, ymax = upr)) +
      geom_pointrange(col = point_col)
  }
  p +
    guides(color = "none") +
    coord_flip() +
    theme_bw() +
    theme(text = element_text(family = "Times New Roman", size = 16),
          plot.title = element_text(face = "bold", size = 20))
}

# Plot simulated differences in expected topic proportions between two values
# of a covariate (method = "difference"), or delegate to plotPointEstimate() /
# plotContinuous() for the other methods.
#
# x             object providing x$topics, x$parameters and the fields read by
#               produce_cmatrix().
# covariate     name of the covariate of interest.
# method        "difference" (default), "pointestimate" or "continuous".
# cov.value1/2  values compared when method = "difference"; each plotted
#               difference is cov.value1 minus cov.value2.
# nsims         coefficient draws per parameter set.
# ci.level      interval coverage.
# The remaining arguments are forwarded to the delegated plotting functions;
# `family` is currently unused.
# Returns the ggplot from prep.plot() for "difference"; otherwise the
# delegated function's result, invisibly.
plotSTMdiff <- function(x, covariate, model = NULL, topics = x$topics,
                        method = "difference",
                        cov.value1 = NULL, cov.value2 = NULL,
                        moderator = NULL, moderator.value = NULL, npoints = 100,
                        nsims = 100, ci.level = 0.95, xlim = NULL, ylim = NULL,
                        xlab = "", ylab = NULL, main = "", printlegend = TRUE,
                        labeltype = "numbers", n = 7, frexw = 0.5, add = FALSE,
                        linecol = NULL, width = 25, verbose.labels = TRUE,
                        family = NULL, custom.labels = NULL,
                        omit.plot = FALSE, ...) {
  # The formal default is the single string "difference". A bare
  # match.arg(method) would therefore reject every other method and leave
  # the branches below unreachable, so list the choices explicitly.
  method <- match.arg(method, c("difference", "pointestimate", "continuous"))
  if (method == "difference" && (is.null(cov.value1) || is.null(cov.value2))) {
    stop("For method='difference' both cov.value1 and cov.value2 must be specified.")
  }
  cthis <- produce_cmatrix(prep = x, covariate = covariate,
                           method = method, cov.value1 = cov.value1,
                           cov.value2 = cov.value2, npoints = npoints,
                           moderator = moderator,
                           moderator.value = moderator.value)
  cdata <- cthis$cdata
  cmat <- cthis$cmatrix
  simbetas <- simBetas(x$parameters, nsims = nsims)
  offset <- (1 - ci.level) / 2

  if (method == "continuous") {
    toreturn <- plotContinuous(prep = x, covariate = covariate,
                               topics = topics, cdata = cdata, cmat = cmat,
                               simbetas = simbetas, offset = offset,
                               xlab = xlab, ylab = ylab, main = main,
                               xlim = xlim, ylim = ylim, linecol = linecol,
                               add = add, labeltype = labeltype, n = n,
                               custom.labels = custom.labels, model = model,
                               frexw = frexw, printlegend = printlegend,
                               omit.plot = omit.plot, ...)
    return(invisible(toreturn))
  }
  if (method == "pointestimate") {
    toreturn <- plotPointEstimate(prep = x, covariate = covariate,
                                  topics = topics, cdata = cdata, cmat = cmat,
                                  simbetas = simbetas, offset = offset,
                                  xlab = xlab, ylab = ylab, main = main,
                                  xlim = xlim, ylim = ylim, linecol = linecol,
                                  add = add, labeltype = labeltype, n = n,
                                  custom.labels = custom.labels, model = model,
                                  frexw = frexw, width = width,
                                  verbose.labels = verbose.labels,
                                  omit.plot = omit.plot, ...)
    return(invisible(toreturn))
  }

  # method == "difference"; both values were validated as non-NULL above.
  dat <- prepDifference(prep = x, covariate = covariate, topics = topics,
                        cdata = cdata, cmat = cmat, simbetas = simbetas,
                        offset = offset, cov.value1 = cov.value1,
                        cov.value2 = cov.value2)
  prep.plot(dat)
}

# Variant of plotSTMdiff() that can pool topics and can return the summary
# data instead of a plot.
#
# flex     "plot" returns the ggplot from prep.plot(); any other value
#          returns the summary data.frame itself (method = "difference" only).
# combine  positions within `topics` to pool into one row via
#          prepDifference.combine(); FALSE/NULL disables pooling.
# method   "difference" (default), "pointestimate" or "continuous"; the latter
#          two delegate to plotPointEstimate() / plotContinuous() and return
#          their result invisibly.
# cov.value1/2  values compared for "difference" (cov.value1 minus cov.value2).
# Other arguments are forwarded as in plotSTMdiff(); `family` is unused.
plotSTMdiff.flex <- function(x, covariate, model = NULL, topics = x$topics,
                             method = "difference",
                             flex = "plot", combine = FALSE,
                             cov.value1 = NULL, cov.value2 = NULL,
                             moderator = NULL, moderator.value = NULL,
                             npoints = 100, nsims = 100, ci.level = 0.95,
                             xlim = NULL, ylim = NULL, xlab = "", ylab = NULL,
                             main = "", printlegend = TRUE,
                             labeltype = "numbers", n = 7, frexw = 0.5,
                             add = FALSE, linecol = NULL, width = 25,
                             verbose.labels = TRUE, family = NULL,
                             custom.labels = NULL, omit.plot = FALSE, ...) {
  # A bare match.arg(method) would take its choices from the single-string
  # default and reject every other method; list the choices explicitly.
  method <- match.arg(method, c("difference", "pointestimate", "continuous"))
  if (method == "difference" && (is.null(cov.value1) || is.null(cov.value2))) {
    stop("For method='difference' both cov.value1 and cov.value2 must be specified.")
  }
  cthis <- produce_cmatrix(prep = x, covariate = covariate,
                           method = method, cov.value1 = cov.value1,
                           cov.value2 = cov.value2, npoints = npoints,
                           moderator = moderator,
                           moderator.value = moderator.value)
  cdata <- cthis$cdata
  cmat <- cthis$cmatrix
  simbetas <- simBetas(x$parameters, nsims = nsims)
  offset <- (1 - ci.level) / 2

  if (method == "continuous") {
    toreturn <- plotContinuous(prep = x, covariate = covariate,
                               topics = topics, cdata = cdata, cmat = cmat,
                               simbetas = simbetas, offset = offset,
                               xlab = xlab, ylab = ylab, main = main,
                               xlim = xlim, ylim = ylim, linecol = linecol,
                               add = add, labeltype = labeltype, n = n,
                               custom.labels = custom.labels, model = model,
                               frexw = frexw, printlegend = printlegend,
                               omit.plot = omit.plot, ...)
    return(invisible(toreturn))
  }
  if (method == "pointestimate") {
    toreturn <- plotPointEstimate(prep = x, covariate = covariate,
                                  topics = topics, cdata = cdata, cmat = cmat,
                                  simbetas = simbetas, offset = offset,
                                  xlab = xlab, ylab = ylab, main = main,
                                  xlim = xlim, ylim = ylim, linecol = linecol,
                                  add = add, labeltype = labeltype, n = n,
                                  custom.labels = custom.labels, model = model,
                                  frexw = frexw, width = width,
                                  verbose.labels = verbose.labels,
                                  omit.plot = omit.plot, ...)
    return(invisible(toreturn))
  }

  # method == "difference"; both values were validated as non-NULL above.
  dat <- if (any(combine != FALSE)) {
    prepDifference.combine(prep = x, covariate = covariate, topics = topics,
                           cdata = cdata, cmat = cmat, simbetas = simbetas,
                           offset = offset, cov.value1 = cov.value1,
                           cov.value2 = cov.value2, combine = combine)
  } else {
    prepDifference(prep = x, covariate = covariate, topics = topics,
                   cdata = cdata, cmat = cmat, simbetas = simbetas,
                   offset = offset, cov.value1 = cov.value1,
                   cov.value2 = cov.value2)
  }

  if (flex == "plot") prep.plot(dat) else dat
}
